{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "tags": [
     "skip-execution"
    ]
   },
   "source": [
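    "# Training and Deploying an MNIST Image Classification Model with PyTorch\n",
    "\n",
    "PyTorch is a popular deep learning framework that offers high flexibility and strong performance, integrates seamlessly with Python's rich ecosystem, and is widely used in image classification, speech recognition, natural language processing, recommendation, AIGC, and other fields. In this example, we use the PAI Python SDK to train a PyTorch model on PAI and then deploy the trained model as an inference service. The main steps are:\n",
    "\n",
    "- Step1: Install and configure the SDK\n",
    "\n",
    "Install the PAI Python SDK and configure the AccessKey, workspace, and OSS Bucket to use.\n",
    "\n",
    "- Step2: Prepare the training data\n",
    "\n",
    "Download the MNIST dataset and upload it to OSS for the training job to use.\n",
    "\n",
    "- Step3: Prepare the training script\n",
    "\n",
    "Take the MNIST training script from the PyTorch examples repository as a template and adapt it slightly for use as the training script.\n",
    "\n",
    "- Step4: Submit the training job\n",
    "\n",
    "Use the Estimator API provided by the PAI Python SDK to create a training job and submit it to run in the cloud.\n",
    "\n",
    "- Step5: Deploy the inference service\n",
    "\n",
    "Deploy the model produced by the training job to PAI-EAS as online inference services, once with a Processor and once with a container image.\n",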
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
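    "## Billing\n",
    "\n",
    "This example uses the following Alibaba Cloud products, which incur corresponding charges:\n",
    "\n",
    "- PAI-DLC: runs the training job. For billing details, see [PAI-DLC billing](https://help.aliyun.com/zh/pai/product-overview/billing-of-dlc)\n",
    "- PAI-EAS: hosts the inference services. For billing details, see [PAI-EAS billing](https://help.aliyun.com/zh/pai/product-overview/billing-of-eas)\n",
    "- OSS: stores the models output by the training job, TensorBoard logs, and similar files. For billing details, see [OSS billing overview](https://help.aliyun.com/zh/oss/product-overview/billing-overview)\n",
    "\n",
    "\n",
    "> By joining the Alibaba Cloud free trial and using the **specified instance types** to submit training jobs or deploy inference services, you can try PAI for free. For details, see [PAI free trial](https://help.aliyun.com/zh/pai/product-overview/free-quota-for-new-users).\n",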
    "\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step1: Install and configure the SDK\n",
    "\n",
    "First, install the PAI Python SDK to run this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!python -m pip install --upgrade pai\n",
    "!python -m pip install pandas"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
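    "The SDK needs an AccessKey for accessing Alibaba Cloud services, as well as the workspace and OSS Bucket to use. After installing the PAI SDK, run the following command in a **terminal** and follow the prompts to configure the credentials, workspace, and other settings.\n",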
    "\n",
    "\n",
    "```shell\n",
    "\n",
    "# Run the following command in a terminal.\n",
    "\n",
    "python -m pai.toolkit.config\n",
    "\n",
    "```\n",
    "\n",
    "We can verify that the configuration has taken effect with the following code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import pai\n",
    "from pai.session import get_default_session\n",
    "\n",
    "print(pai.__version__)\n",
    "\n",
    "sess = get_default_session()\n",
    "\n",
    "# Get the configured workspace\n",
    "assert sess.workspace_name is not None\n",
    "print(sess.workspace_name)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
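    "## Step2: Prepare the training data\n",
    "\n",
    "In this example, we train an image classification model on the MNIST dataset. So that the training job can load the data, we first upload it to an OSS Bucket.\n"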
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the following shell script to download the MNIST dataset to the local directory `data`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "%%sh\n",
    "\n",
    "#!/bin/sh\n",
    "set -e\n",
    "\n",
    "url_prefix=\"https://ossci-datasets.s3.amazonaws.com/mnist/\"\n",
    "# If downloads from the address above are slow, you can use the following address instead\n",
    "# url_prefix=\"http://yann.lecun.com/exdb/mnist/\"\n",
    "\n",
    "mkdir -p data/MNIST/raw/\n",
    "\n",
    "wget ${url_prefix}train-images-idx3-ubyte.gz -P data/MNIST/raw/\n",
    "wget ${url_prefix}train-labels-idx1-ubyte.gz -P data/MNIST/raw\n",
    "wget ${url_prefix}t10k-images-idx3-ubyte.gz -P data/MNIST/raw\n",
    "wget ${url_prefix}t10k-labels-idx1-ubyte.gz -P data/MNIST/raw\n"
   ]
  },
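  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before uploading, it can help to confirm that all four archives landed in `data/MNIST/raw/`. The following is a minimal, standard-library-only sketch that assumes the directory layout and file names used by the script above; `missing_mnist_files` is a hypothetical helper, not part of the PAI SDK.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "# File names downloaded by the shell script above\n",
    "MNIST_FILES = [\n",
    "    \"train-images-idx3-ubyte.gz\",\n",
    "    \"train-labels-idx1-ubyte.gz\",\n",
    "    \"t10k-images-idx3-ubyte.gz\",\n",
    "    \"t10k-labels-idx1-ubyte.gz\",\n",
    "]\n",
    "\n",
    "\n",
    "def missing_mnist_files(raw_dir=\"data/MNIST/raw\"):\n",
    "    \"\"\"Hypothetical helper: return the archives that are absent or empty in raw_dir.\"\"\"\n",
    "    missing = []\n",
    "    for name in MNIST_FILES:\n",
    "        path = os.path.join(raw_dir, name)\n",
    "        if not os.path.isfile(path) or os.path.getsize(path) == 0:\n",
    "            missing.append(name)\n",
    "    return missing\n",
    "\n",
    "\n",
    "# An empty list means all four files are present\n",
    "print(missing_mnist_files())\n",
    "```"
   ]
  },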
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the OSS upload API provided by the PAI Python SDK to upload the data to the OSS Bucket."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.common.oss_utils import upload\n",
    "from pai.session import get_default_session\n",
    "\n",
    "sess = get_default_session()\n",
    "data_path = \"./data\"\n",
    "\n",
    "data_uri = upload(data_path, oss_path=\"mnist/data/\", bucket=sess.oss_bucket)\n",
    "\n",
    "print(data_uri)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
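    "## Step3: Prepare the training script\n",
    "\n",
    "Training a model with PyTorch requires a training script. Here we start from the official PyTorch [MNIST example](https://github.com/pytorch/examples/blob/main/mnist/main.py) and modify its data-loading and model-saving logic to produce our training script.\n",
    "\n",
    "- Read the input data path from an environment variable\n",
    "\n",
    "The training data is mounted into the training job's environment, so the training code must read the data from the specified path.\n",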
    "\n",
    "\n",
    "```diff\n",
    "\n",
    "-    dataset1 = datasets.MNIST(\"../data\", train=True, download=True, transform=transform)\n",
    "-    dataset2 = datasets.MNIST(\"../data\", train=False, transform=transform)\n",
    "\n",
    "+    # Use the data mounted into the training container (by default /ml/input/{ChannelName}); the path is available from the environment variable `PAI_INPUT_{ChannelNameUpperCase}`\n",
    "+    data_path = os.environ.get(\"PAI_INPUT_TRAIN_DATA\")\n",
    "+    dataset1 = datasets.MNIST(data_path, train=True, download=True, transform=transform)\n",
    "+    dataset2 = datasets.MNIST(data_path, train=False, transform=transform)\n",
    "\n",
    "\n",
    "```\n",
    "\n",
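    "- Read the model save path from an environment variable:\n",
    "\n",
    "The model must be saved to a specific path in the training container so that the PAI training service can persist it to the OSS Bucket. By default, save the model under the path given by the environment variable `PAI_OUTPUT_MODEL` (which defaults to `/ml/output/model`).\n",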
    "\n",
    "\n",
    "```diff\n",
    "\n",
    "-     if args.save_model:\n",
    "-         torch.save(model.state_dict(), \"mnist_cnn.pt\")\n",
    "\n",
    "+     # Save the model\n",
    "+     save_model(model)\n",
    "+\n",
    "+\n",
    "+ def save_model(model):\n",
    "+     \"\"\"Convert the model to TorchScript and save it to the specified path.\"\"\"\n",
    "+\n",
    "+     output_model_path = os.environ.get(\"PAI_OUTPUT_MODEL\")\n",
    "+     os.makedirs(output_model_path, exist_ok=True)\n",
    "+\n",
    "+     m = torch.jit.script(model)\n",
    "+     m.save(os.path.join(output_model_path, \"mnist_cnn.pt\"))\n",
    "\n",
    "```\n",
    "\n",
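    "The built-in [PyTorch Processor](https://help.aliyun.com/document_detail/470458.html) provided by PAI requires the model to be in [TorchScript format](https://pytorch.org/docs/stable/jit.html) when creating a service. In this example, we export the model in TorchScript format and then create inference services with both the PyTorch Processor and a container image.\n"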
   ]
  },
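  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The path conventions above can be sketched in plain Python. This is a minimal illustration that assumes only the naming rule described here: `PAI_INPUT_` followed by the upper-cased channel name for input data, and `PAI_OUTPUT_MODEL` for the model directory. `channel_env_var` and `resolve_paths` are hypothetical helpers, and the local fallbacks mirror the defaults used in `train.py`.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "\n",
    "def channel_env_var(channel_name):\n",
    "    \"\"\"Hypothetical helper: map an input channel name to its environment variable.\"\"\"\n",
    "    return \"PAI_INPUT_\" + channel_name.upper()\n",
    "\n",
    "\n",
    "def resolve_paths(channel_name, local_data=\"../data\", local_model=\"./model/\"):\n",
    "    \"\"\"Hypothetical helper: return (data_path, model_dir), falling back to local paths outside PAI.\"\"\"\n",
    "    data_path = os.environ.get(channel_env_var(channel_name), local_data)\n",
    "    model_dir = os.environ.get(\"PAI_OUTPUT_MODEL\", local_model)\n",
    "    return data_path, model_dir\n",
    "\n",
    "\n",
    "print(channel_env_var(\"train_data\"))  # PAI_INPUT_TRAIN_DATA\n",
    "print(resolve_paths(\"train_data\"))\n",
    "```\n",
    "\n",
    "Falling back to local defaults keeps the same script runnable on a local machine, where neither environment variable is set."
   ]
  },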
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the following code to create the training script directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "!mkdir -p train_src"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the training job script to the `train_src` directory. The complete script is as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "%%writefile train_src/train.py\n",
    "\n",
    "# source: https://github.com/pytorch/examples/blob/main/mnist/main.py\n",
    "from __future__ import print_function\n",
    "\n",
    "import argparse\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import StepLR\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
    "        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
    "        self.dropout1 = nn.Dropout(0.25)\n",
    "        self.dropout2 = nn.Dropout(0.5)\n",
    "        self.fc1 = nn.Linear(9216, 128)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = F.relu(x)\n",
    "        x = F.max_pool2d(x, 2)\n",
    "        x = self.dropout1(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.dropout2(x)\n",
    "        x = self.fc2(x)\n",
    "        output = F.log_softmax(x, dim=1)\n",
    "        return output\n",
    "\n",
    "\n",
    "def train(args, model, device, train_loader, optimizer, epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if batch_idx % args.log_interval == 0:\n",
    "            print(\n",
    "                \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\".format(\n",
    "                    epoch,\n",
    "                    batch_idx * len(data),\n",
    "                    len(train_loader.dataset),\n",
    "                    100.0 * batch_idx / len(train_loader),\n",
    "                    loss.item(),\n",
    "                )\n",
    "            )\n",
    "            if args.dry_run:\n",
    "                break\n",
    "\n",
    "\n",
    "def test(model, device, test_loader):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            test_loss += F.nll_loss(\n",
    "                output, target, reduction=\"sum\"\n",
    "            ).item()  # sum up batch loss\n",
    "            pred = output.argmax(\n",
    "                dim=1, keepdim=True\n",
    "            )  # get the index of the max log-probability\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "\n",
    "    print(\n",
    "        \"\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n\".format(\n",
    "            test_loss,\n",
    "            correct,\n",
    "            len(test_loader.dataset),\n",
    "            100.0 * correct / len(test_loader.dataset),\n",
    "        )\n",
    "    )\n",
    "\n",
    "\n",
    "def main():\n",
    "    # Training settings\n",
    "    parser = argparse.ArgumentParser(description=\"PyTorch MNIST Example\")\n",
    "    parser.add_argument(\n",
    "        \"--batch-size\",\n",
    "        type=int,\n",
    "        default=64,\n",
    "        metavar=\"N\",\n",
    "        help=\"input batch size for training (default: 64)\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--test-batch-size\",\n",
    "        type=int,\n",
    "        default=1000,\n",
    "        metavar=\"N\",\n",
    "        help=\"input batch size for testing (default: 1000)\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--epochs\",\n",
    "        type=int,\n",
    "        default=14,\n",
    "        metavar=\"N\",\n",
    "        help=\"number of epochs to train (default: 14)\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--lr\",\n",
    "        type=float,\n",
    "        default=1.0,\n",
    "        metavar=\"LR\",\n",
    "        help=\"learning rate (default: 1.0)\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--gamma\",\n",
    "        type=float,\n",
    "        default=0.7,\n",
    "        metavar=\"M\",\n",
    "        help=\"Learning rate step gamma (default: 0.7)\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--no-cuda\", action=\"store_true\", default=False, help=\"disables CUDA training\"\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--dry-run\",\n",
    "        action=\"store_true\",\n",
    "        default=False,\n",
    "        help=\"quickly check a single pass\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--seed\", type=int, default=1, metavar=\"S\", help=\"random seed (default: 1)\"\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--log-interval\",\n",
    "        type=int,\n",
    "        default=10,\n",
    "        metavar=\"N\",\n",
    "        help=\"how many batches to wait before logging training status\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--save-model\",\n",
    "        action=\"store_true\",\n",
    "        default=False,\n",
    "        help=\"For Saving the current Model\",\n",
    "    )\n",
    "    args = parser.parse_args()\n",
    "    use_cuda = not args.no_cuda and torch.cuda.is_available()\n",
    "\n",
    "    torch.manual_seed(args.seed)\n",
    "\n",
    "    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "\n",
    "    train_kwargs = {\"batch_size\": args.batch_size}\n",
    "    test_kwargs = {\"batch_size\": args.test_batch_size}\n",
    "    if use_cuda:\n",
    "        cuda_kwargs = {\"num_workers\": 1, \"pin_memory\": True, \"shuffle\": True}\n",
    "        train_kwargs.update(cuda_kwargs)\n",
    "        test_kwargs.update(cuda_kwargs)\n",
    "\n",
    "    transform = transforms.Compose(\n",
    "        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
    "    )\n",
    "\n",
    "    data_path = os.environ.get(\"PAI_INPUT_TRAIN_DATA\", \"../data\")\n",
    "    dataset1 = datasets.MNIST(data_path, train=True, download=True, transform=transform)\n",
    "    dataset2 = datasets.MNIST(data_path, train=False, transform=transform)\n",
    "    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)\n",
    "    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)\n",
    "\n",
    "    model = Net().to(device)\n",
    "    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)\n",
    "\n",
    "    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)\n",
    "    for epoch in range(1, args.epochs + 1):\n",
    "        train(args, model, device, train_loader, optimizer, epoch)\n",
    "        test(model, device, test_loader)\n",
    "        scheduler.step()\n",
    "\n",
    "    # Save the model\n",
    "    save_model(model)\n",
    "\n",
    "\n",
    "def save_model(model):\n",
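    "    \"\"\"Convert the model to TorchScript and save it to the specified path.\"\"\"\n",
    "    # Export in eval mode so that Dropout is disabled in the saved model\n",
    "    model.eval()\n",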
    "    output_model_path = os.environ.get(\"PAI_OUTPUT_MODEL\", \"./model/\")\n",
    "    os.makedirs(output_model_path, exist_ok=True)\n",
    "\n",
    "    m = torch.jit.script(model)\n",
    "    m.save(os.path.join(output_model_path, \"mnist_cnn.pt\"))\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
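    "## Step4: Submit the training job\n",
    "\n",
    "`Estimator` lets you run a local training script as a training job in the cloud, using a specified image. With `Estimator`, we submit the training script prepared above to PAI and run it with a PyTorch image provided by PAI.\n"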
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from pai.estimator import Estimator\n",
    "from pai.image import retrieve\n",
    "\n",
    "\n",
    "# Use the PyTorch GPU training image provided by PAI\n",
    "image_uri = retrieve(\n",
    "    \"PyTorch\",\n",
    "    framework_version=\"1.8PAI\",\n",
    "    accelerator_type=\"GPU\",\n",
    ").image_uri\n",
    "\n",
    "print(image_uri)\n",
    "\n",
    "\n",
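    "# Configure the training job\n",
    "est = Estimator(\n",
    "    # Command that starts the training job\n",
    "    command=\"python train.py --epochs 5 --batch-size 256 --lr 0.5\",\n",
    "    # Local code directory to upload\n",
    "    source_dir=\"./train_src/\",\n",
    "    # Training job image\n",
    "    image_uri=image_uri,\n",
    "    # Instance configuration\n",
    "    # For the instance types supported by the PAI training service, see: [Public resource group instances and pricing](https://help.aliyun.com/document_detail/171758.html?#section-55y-4tq-84y)\n",
    "    instance_type=\"ecs.gn6i-c4g1.xlarge\",  # 4vCPU 15GB 1*NVIDIA T4\n",
    "    # Metric capture configuration of the training job\n",
    "    # The training service captures job metrics by matching regular expressions against the job's output logs (the standard output and standard error printed by the training script).\n",
    "    # train.py prints lines such as \"... Loss: 0.123456\", so the regex looks for \"Loss: \".\n",
    "    metric_definitions=[\n",
    "        {\n",
    "            \"Name\": \"loss\",\n",
    "            \"Regex\": r\".*Loss: ([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*\",\n",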
    "        },\n",
    "    ],\n",
    "    base_job_name=\"pytorch_mnist\",\n",
    ")"
   ]
  },
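  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The metric regex only captures values if it matches what the training script actually prints. `train.py` logs lines ending in `Loss: <value>`, so the pattern has to look for `Loss: ` rather than `loss=`. The following is a quick local check with Python's standard `re` module; it builds a log line from the script's own format string, and the loss value is made up for illustration.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Log format used by the train() function in train.py\n",
    "LOG_FORMAT = \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\"\n",
    "\n",
    "# Pattern that matches the script's \"Loss: <value>\" output\n",
    "LOSS_REGEX = r\".*Loss: ([-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?).*\"\n",
    "\n",
    "line = LOG_FORMAT.format(1, 0, 60000, 0.0, 2.305364)  # made-up loss value\n",
    "match = re.match(LOSS_REGEX, line)\n",
    "print(float(match.group(1)))  # 2.305364\n",
    "\n",
    "# A pattern that looks for \"loss=\" finds nothing in this log format\n",
    "print(re.match(r\".*loss=([-+]?[0-9]*\\.?[0-9]+).*\", line))  # None\n",
    "```\n",
    "\n",
    "The training service applies `metric_definitions` to the job logs on its side; this local check only confirms that the pattern and the log format agree."
   ]
  },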
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
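    "`estimator.fit` submits the training job to PAI. After submission, the SDK prints a link to the job details page and the job's logs, and waits for the job to finish.\n",
    "\n",
    "To use data stored on OSS directly, pass it through the `inputs` parameter of `estimator.fit`. Each data path passed through `inputs` is mounted into the training container, where its mount path is exposed through the environment variable `PAI_INPUT_{ChannelNameUpperCase}`, so the training script can load the data by reading local files."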
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Submit the training job with the .fit method\n",
    "est.fit(\n",
    "    inputs={\n",
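    "        # Input data of the training job. Each key-value pair is a channel; the training script gets the corresponding data path from the environment variable PAI_INPUT_{ChannelNameUpperCase}\n",
    "        # For example, for train_data below, the training script gets the mounted data path from `PAI_INPUT_TRAIN_DATA`.\n",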
    "        \"train_data\": data_uri,\n",
    "    }\n",
    ")\n",
    "\n",
    "# Path of the model produced by the training job\n",
    "print(\"TrainingJob output model data:\")\n",
    "print(est.model_data())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
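    "## Step5: Deploy the inference service\n",
    "\n",
    "After the training job finishes, `estimator.model_data()` returns the OSS path of the model produced by the job. In the following steps, we deploy this model to PAI as an online inference service.\n",
    "\n",
    "Deploying an inference service involves two main steps:\n",
    "\n",
    "- Use `InferenceSpec` to describe how to build the inference service from the model\n",
    "\n",
    "You can deploy the model either with a Processor or with a custom image. The examples below deploy the trained PyTorch model in both ways.\n",
    "\n",
    "- Call `Model.deploy` to configure the service's resources, name, and other settings, and create the inference service.\n",
    "\n",
    "For details on deploying inference services, see [Documentation: Deploy inference services](https://pai-sdk.oss-cn-shanghai.aliyuncs.com/pai/doc/latest/user-guide/model.html)"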
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
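    "### Deploying with a Processor\n",
    "\n",
    "A [Processor](https://help.aliyun.com/document_detail/111029.html) is PAI's abstraction of an inference service program package: it loads the model and starts the inference service, which exposes an API that users can call.\n",
    "\n",
    "PAI provides a built-in [PyTorch Processor](https://help.aliyun.com/document_detail/470458.html) that makes it easy to deploy a TorchScript model to PAI as an inference service.\n",
    "\n",
    "In the following code, we use the PyTorch Processor to deploy the trained model as an inference service."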
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pai.model import Model, InferenceSpec\n",
    "from pai.predictor import Predictor\n",
    "from pai.common.utils import random_str\n",
    "\n",
    "\n",
    "m = Model(\n",
    "    model_data=est.model_data(),\n",
    "    # Use the PyTorch Processor provided by PAI\n",
    "    inference_spec=InferenceSpec(processor=\"pytorch_cpu_1.10\"),\n",
    ")\n",
    "\n",
    "p: Predictor = m.deploy(\n",
    "    service_name=\"tutorial_pt_mnist_proc_{}\".format(random_str(6)),\n",
    "    instance_type=\"ecs.c6.xlarge\",\n",
    ")\n",
    "\n",
    "print(p.service_name)\n",
    "print(p.service_status)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
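    "The `Predictor` object returned by `Model.deploy` points to the created inference service. Use `Predictor.predict` to send prediction requests to the service and get the results.\n",
    "\n",
    "We build a test sample with `numpy` and send it to the inference service."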
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# The TorchScript model saved above expects Float32 input with shape (BatchSize, Channel, Height, Width)\n",
    "dummy_input = np.random.rand(2, 1, 28, 28).astype(np.float32)\n",
    "\n",
    "res = p.predict(dummy_input)\n",
    "print(res)\n",
    "\n",
    "print(np.argmax(res, 1))"
   ]
  },
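  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `Net.forward` ends with `F.log_softmax`, each row that the service returns holds log-probabilities rather than probabilities. `np.argmax` on a row still gives the predicted digit, and exponentiating recovers probabilities that sum to 1. The following is a minimal standard-library sketch of that post-processing; `to_prediction` is a hypothetical helper that works for any row length, and the example row uses 3 made-up classes for brevity instead of the model's 10.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def to_prediction(log_probs):\n",
    "    \"\"\"Hypothetical helper: turn one row of log-softmax scores into (label, probabilities).\"\"\"\n",
    "    probs = [math.exp(v) for v in log_probs]\n",
    "    label = max(range(len(probs)), key=probs.__getitem__)\n",
    "    return label, probs\n",
    "\n",
    "\n",
    "# Made-up log-probabilities for a 3-class illustration\n",
    "row = [math.log(0.1), math.log(0.7), math.log(0.2)]\n",
    "label, probs = to_prediction(row)\n",
    "print(label)  # 1\n",
    "print(round(sum(probs), 6))  # 1.0\n",
    "```"
   ]
  },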
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "After testing, delete the inference service with `Predictor.delete_service`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "p.delete_service()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
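    "### Deploying with an Image\n",
    "\n",
    "Inference services started in Processor mode perform well and suit performance-sensitive scenarios. For scenarios that need more flexible customization, for example when the model relies on third-party dependencies or the inference service needs pre- and post-processing, you can deploy with an image instead.\n",
    "\n",
    "The SDK provides the `pai.model.container_serving_spec()` method, which lets you create an inference service from local inference code combined with a base image provided by PAI.\n",
    "\n",
    "Before deploying with an image, we need to prepare the model serving code, which loads the model, starts an HTTP server, and handles inference requests. We write the serving code with Flask, as follows:"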
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the directory for the inference code\n",
    "!mkdir -p infer_src"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile infer_src/run.py\n",
    "\n",
    "\n",
    "import json\n",
    "from flask import Flask, request\n",
    "from PIL import Image\n",
    "import os\n",
    "import torch\n",
    "import torchvision.transforms as transforms\n",
    "import numpy as np\n",
    "import io\n",
    "\n",
    "app = Flask(__name__)\n",
    "# The model specified at deployment time is loaded to this path by default.\n",
    "MODEL_PATH = \"/eas/workspace/model/\"\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = torch.jit.load(os.path.join(MODEL_PATH, \"mnist_cnn.pt\"), map_location=device).to(device)\n",
    "# Run the model in inference mode (disables Dropout)\n",
    "model.eval()\n",
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
    ")\n",
    "\n",
    "\n",
    "@app.route(\"/\", methods=[\"POST\"])\n",
    "def predict():\n",
    "    # Preprocess the image: decode the request body and convert it to single-channel grayscale\n",
    "    im = Image.open(io.BytesIO(request.data)).convert(\"L\")\n",
    "    input_tensor = transform(im).to(device)\n",
    "    # Add the batch dimension: (C, H, W) -> (1, C, H, W)\n",
    "    input_tensor.unsqueeze_(0)\n",
    "    # Run inference without tracking gradients\n",
    "    with torch.no_grad():\n",
    "        output_tensor = model(input_tensor)\n",
    "    pred_res = output_tensor.cpu().numpy()[0]\n",
    "\n",
    "    return json.dumps(pred_res.tolist())\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    app.run(host=\"0.0.0.0\", port=int(os.environ.get(\"LISTENING_PORT\", 8000)))\n"
   ]
  },
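  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`run.py` applies the same preprocessing as training: `transforms.ToTensor()` scales 8-bit pixel values into [0, 1], and `transforms.Normalize((0.1307,), (0.3081,))` then subtracts the mean 0.1307 and divides by the standard deviation 0.3081. The following standard-library sketch shows that arithmetic for a single grayscale pixel; `normalize_pixel` is a hypothetical helper used only for illustration.\n",
    "\n",
    "```python\n",
    "MEAN, STD = 0.1307, 0.3081  # values passed to transforms.Normalize in run.py\n",
    "\n",
    "\n",
    "def normalize_pixel(value):\n",
    "    \"\"\"Hypothetical helper: apply ToTensor scaling and Normalize to one 0-255 pixel.\"\"\"\n",
    "    scaled = value / 255.0  # ToTensor: uint8 [0, 255] -> float [0, 1]\n",
    "    return (scaled - MEAN) / STD  # Normalize: (x - mean) / std\n",
    "\n",
    "\n",
    "for v in (0, 128, 255):\n",
    "    print(v, round(normalize_pixel(v), 4))\n",
    "```\n",
    "\n",
    "Because MNIST digits are mostly black background (pixel value 0), most normalized inputs sit near -0.42; sending raw 0-255 values instead would feed the model inputs far outside the range it was trained on."
   ]
  },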
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
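    "Using `pai.model.container_serving_spec`, we create an `InferenceSpec` object from the local script and a `PyTorch` image provided by PAI.\n",
    "\n",
    "- Model serving code and startup command:\n",
    "\n",
    "The local script directory specified by `source_dir` is uploaded to OSS and then mounted into the service container (by default to the `/ml/usercode` directory).\n",
    "\n",
    "- Inference service image:\n",
    "\n",
    "PAI provides base inference images. Use the `pai.image.retrieve` method with the parameter `image_scope=ImageScope.INFERENCE` to get an inference image provided by PAI.\n",
    "\n",
    "- Third-party dependencies of the model service:\n",
    "\n",
    "Dependencies of the model serving code or of the model can be specified through the `requirements` parameter; they are installed into the environment before the service program starts.\n"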
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.model import InferenceSpec, container_serving_spec\n",
    "from pai.image import retrieve, ImageScope\n",
    "\n",
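    "torch_image_uri = retrieve(\n",
    "    framework_name=\"pytorch\",\n",
    "    framework_version=\"1.12\",\n",
    "    accelerator_type=\"CPU\",\n",
    "    # Request an inference image rather than a training image\n",
    "    image_scope=ImageScope.INFERENCE,\n",
    ").image_uri\n",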
    "\n",
    "inf_spec = container_serving_spec(\n",
    "    command=\"python run.py\",\n",
    "    source_dir=\"./infer_src/\",\n",
    "    image_uri=torch_image_uri,\n",
    "    requirements=[\n",
    "        \"flask==2.2.2\",\n",
    "        \"Werkzeug==2.3.7\",\n",
    "    ],\n",
    ")\n",
    "print(inf_spec.to_dict())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using the model output by the training job and the InferenceSpec above, we deploy an online inference service with the `Model.deploy` API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.model import Model\n",
    "from pai.common.utils import random_str\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "m = Model(\n",
    "    model_data=est.model_data(),\n",
    "    inference_spec=inf_spec,\n",
    ")\n",
    "\n",
    "predictor = m.deploy(\n",
    "    service_name=\"torch_mnist_script_container_{}\".format(random_str(6)),\n",
    "    instance_type=\"ecs.c6.xlarge\",\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "We prepare an MNIST test image to send to the inference service.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAAAAAB5Gfe6AAAxyUlEQVR4nO196XrjuM4mwE2rnfTpmfu/wDldldjauWB+kJSoxYmdcqrOfHPQT1ccR6JICATAFyCI8BTC8B8AEBCQ//DI/dvbEREZIgAQOSI6ag33T0FEhHAbERzetiL2QC//R5L40x34iMi/44O3iE97xn80Ax6bRl+j/++nwP/3DHjiFKDVj5luTtedeN/4IrUAuPn563PkaQygw07hwad4Aa67T9EMLt9gNKeRBRjbwdgO7nlAD2nIpyrBG68D538+vza9AiMHyF+OCwe8B4AHrczf3ikbT2KAf+z8Cil2MvzYjB5jN2dPhmA1Goythh80f4/pLCCIPFp35hEReJ4EEC4dXhPuXv/BRQv/okuItG7QcxLntiL3tiyYx0/3SMETdcB+EieTdqcD9tduu5u+3FmkcGGAb+W2CNw3B36PDsCv6AAgmnVdbAkTBvg/HYw/mYaf03+2J/j9juB/HaGnSsBHU+/pL/PAPZo//oEpgJSYp7kTyc+1EqS9il60++bO+NelheWCWZ3OTyec7cBdLHiaBCS2LnXmDlGLjX2DdBCzYdsMH1YCRrEVineujM3Moc89gicxIHHSoiuGoQMH7+FQQpcRHPV6lqvZPV4Zh/S+lLmfC8ETdQBuejAT7T322yy4YdiWOzFK0CwLiDMPILCI4qWf0tN0wG7dM7u4dOSy7/ze3Tpv8/dZpChlAMYpEJ3keAHEH5/R90vAXQ5pCop+1HTyVmcRDwzEVILutzn/2Y7QAS3qDyAdPN6cgR/Tfx2hP/NYvPH5c5olIG1hJQH79eFH9MTl8Pxz7+HQ0Sj3Uz7qrh18sHtI+tdd08EZ+r2rwbWpWzs4n6h4jGjIHkfYf7pFEXYhpNjifYLwVFA0vNNgoFe9vv3L4sBgtOx09Of5KXvCFHXCIAF3Tq1nQWI4I0I7F/YIEtws6pOf97y19SNWfgQGmCRAbr/VE6SDj7fGvYF1wgSITuyGZeny6h64DVNv/DN6ng6IDzz2cNfu3qEEHMvP9jFrxG21QpyXXRGg/c064NfX/L94f3L7ofd9SP91hP50B75KW5R5/vigGH0HA450z61uzes2SjXnSmMcAACzzvexgzV+Tn8oMBJp7+XuFPLGXuycv90Ath7WCgHa/vExIXg6A/ZW/8i8rzpNAAnEmTo1m4vWj9lghBvP5971wFPxgOX9rGiRgH2nCPZ/uOU8JLhrwD+S8NlBf+gwcrihp8Pie8f1Ns6xXLNtZw17eC2xYjAukjaPMj7soeXlNyrB2cM9dsqOWZLk28W2aH9DdP4De2gdWb53GQAAv88M/qqPtKV5+X+gYR6SgP86Qn/u0Ycv6n4wJwVCfkG+voMBR5p9QytjuVq34KLYN2jKSqXPSpAOVj640UIf0dMZQJ9BWAAJCr5aHibT2YNEuFKCmCxxMP7q1SGt2PMQNvgtU+Cz5+L834pZOI+eCJCWgOvOiYivHxPoa+aAT6OZsbFP6OmACN4zBRBCTjcmo/Mh0RjW8BiTvzh+3rhLHgPcPH/+852o6DdNgQ+fjeBRbMQVA2bhj/8eiPJmUMnoEwZFJ/F3g6JLZz67AH04E7eTwHf70DW81fQsBan8/9nV4H2EeMiAu+H8tKlf6sh/HaHvaviWckq/oLUOIKBknwtt70n24iTr49Tx/+CZN+kbGLAVyT0+S0CLFfDfBAasLiRarqBVG0T+ZkzSBWCbfHu8z+iz3v4yHa1H6fCV4GyqvuLJIoMlb9bjAiu4iLw78blC/sKzP2two9kI79m99YUH4bIsxgU2SkXmjsc+mwEY/bz511m6v4G8OSXvMqa4SOTHb/UEQ58gtcRRwz1u3O4iCq+eZrGbp9S9D/wOJbjFRT/zDH+F4oAXji/5M3/GFfa9eMwb292PiPdq8XkVkSwdH5lwfzIydMAjRMYYY4hA5Jxz
1q3+GLYLOFhbzF+Rr9/KAJx197wYCr47zn/yDAByzhprrJ1NKAIyYIAIROTAUWTDr02v718MBQfP/+K3RLMwTMb8Z0QW1wYYBMBZrfWktXEU96KgvxqAHFmyzjkKnuN29fgIff9ymAgoCjIiY8gZ55xzLsIHxhnjnh8Q2ANk9TSMwzBq65yf4YiMB+mwzlhrjXPOOXKJZ+0fApsvPqTfIgHzREbGGOdcCiGlkEIKT5wLzweMm+adHoeu67p+NNY58BOACc45Q3DWaqONsdZa66K3Q2vDQ/BnIDHabWNKNzagH75USmUyk0oqqaSUQkghOOc8qgaw09A2zbVph8lYzwDGmBSCcwRntNbTpI0xDC3hAiVFCCl05R7/4/kSQHPaW/h9gXMY40IIJVWW5XmWqyzLVKZUJqWUQgjBWQAKyYxd837JlOxHbS14BgghpeCMnNbjNIpJa2YAyAVMdMYEZ6D4Hq3wDTrgeEcnAHLOhZQyU1meF0Ve5HmeZ3mWZZlUUgrpGQAASHpoL3kmOFfDZC0F4VFSCcHI6nEcBOecMSRyzAEELCk8aN6W+WdR4Q2Wi1wIIZVUuSryoiyKssgDF5TKAgOi9Oq+UZwBMqkmExkglPIMmMZBeoVIzlnE7UP/QyCxNYbLuOBSqkxleVYURVkWZVEURVHkeZ4pJaWfAuFWrTg4Zx1xuWaAFIzMNAjuzSY5Z9mvOJ2/wxFCQATGhZRKZXmWFVlRFHH4ZZHnWaakFJwn0E4GzkxaWxIHEmBGKThDZAycteY7GbBv+m6QJbmQIWNcSKXyrCjyvMiKvCjyvAgSkOVKCs7XbSiTl8Y4zIrJWiJAYFwoKSVnpCclOEPGGThrNWMxx/YrnBAfuA0roxIu2QagjiHslfVB5JxLqbK8KMuyKIssz7M8z/KsyIuiyLJMiV3XkcvcWBDFqK0lAgDGuZRScCQ9KSk444yRs0ZwXCEB+BgfPpsCK4TrHsx6RoQS68elUFlWFFVd1WVVZHmmlFJZluV+AsSp74AAgkSjUIVDNU3GWiICQM6l4IIxmkYZHCeyWk98PQcI6ebLeYgBOP+fxl8CDzC5irb3rR+PyIV//XV9Op+qqswyKaWQUmVZnuWZDNcba8kBMsY5A+tQOhCZMdZZIhfUKGcMSXsdwBCcnibB90rgAUREfMCsZfhhVt9hWDHeskgL41LleVlWp5fzy/lUlZnyfp9UKssyFfo8TVo7AuCcC8bIEZMocme9v08AjHPGGCJNinOGiERmHIYdA9J9VXdDYkuEcnVP1AIz0Jj8+QbnEJaQdRi/kFleVtXp/Pr6ej7XpZKcIzLGpVIydIDGYRwn64KkcwQAzhSBIyK/4mHM2z4aJWNA5JwdM9/WugsUUeF7aDUFjhCKNbxzR6Mx4hU5wLlUeVHV9enl5fWvl3NdSsn88pbL2fSPfd8Po3GITEjJBWeMc86QAfhiYgB+tQhAggM4a43x42czA+YP9yKiWwY8i+JWTgA/AbK8rKr6dDqfz+dzXUqBQESEjIe4txv7vuv6wThA7leKXEpAJjhD9Kt+QM82IAbO6mnyzkMy/i/R9zlCAbMXQqqi9MM/neq6rgohAJxzDtCDW0R66Luu9QxgQkqhpFSZA8aR7eBFlNYvnfajvz9LfqaUAUtI4Qnk14SIXGZ5UdWnl9P5fK7rsigyyQCsV+5E5Ky1eujarutH7QA4l0IopfIyN46A9j5CWAcxXCuiGGt+BBRfMSBZTEVaL+zvi+/QevxCqqKqzy+vp/PL+VQWmVIIABwNWWetMXrSZhy6ruuHwTgCxgWXSuVFVU26tEqKnZQyRFyCqXOvkCHQnD57rxmcube4D4tj5e1ZYlk+1ywUdysRBgc+L6rTy+vr+XQ+VWWugtlHBLLTOI5jP45j33f9MIzGkl81q6yo6tM4GZ1nUoqNn0wUMWPnkcHQop8PngO4wYhuMmD+FOVng+ekBR1mhPa22x0TPLw8Ms5lllf1+fWv13Pt
xx8HQ85MQ9d1bdsNQ9/3wzBNxjpCxphQeVmfh1EbrU1mpeSrPAZyzjobyLmICiIizEl294ZGohle/L44iFgcKvF/N6zZf5fwBgGQMSGzrCjr88vr67mqizxTcfzOWT12zeV6vbRdN/TDOE3aWEeAyHiWV+dh0tZaY61zJEXCAXLOWq211tp4cDiOP3pgX5kC4d/0tsXvOSzrcUQzLIWAiIyLLC/K+nQ+v7yeyjJfpNkaPY1De3l/e3u7Nn0/DJM2xljnEAB5VtSDtiH+D4QAanmItSagglobY92CubMosb+mBGH1yqMeuNezwFgPFRnnKi/KqjqdzqfTqSgUD24rWaPHYWjb69uPHz8v124YRg/zOgcAIFQ5WeRSKhmAc24XPeDHP07TGEUgjN/DQY/ZwpUZPPRuH7OKiAjIGGcMOecyy+u6PtV1XddVnguP+ZEzdtJD37bXy/vbzx/v124YJ2Osi6MRmUWRl8M46SAXC/oPZI3R0ziN4zRqY2wadXm4v9+xZwgZY0IwLoRUeXF6OZ/qqiqLPFP+7ZM1Wo9j3zfXy+VyuVyu126YJuPsPEyLYgpD29TOAgJrrZ6mcRiGYQxq4xfocwnwX6Xr3w9wAQIEZFwIKYVSWZ5X9ctfL+fTMn7nrJnGse+7tmnef769X5um7cdJrwZChMg4F1JIvwZe8p+ds3qahr5r266PYYMt3T8NEjO46IADdR+jDiESfdvAEjDGhcpUVuRFWVan+vWvl7oqlGQAAOSMGYe+a9umubaX97f3S9sN4zSZ9JlefFSWZXmmpBBsWfESWaOnoeuapmnafvCQYXgxQWshHhSu+ZQBKd18wwjraXaM/wuV5UVZVXVV13V1enk5ldH+WaOnvmuul+vlcu3a6/XStMOkwyji+FWW53lRlmVZlrnyEGh8orNGj33XNJfLtenGybhFPqOjggdL+zsYMCdZ7+9db+hGIMBbfjEyLlVeVrVf+1V5Vdd1kUkAAHBWj2N3fX97e3u7XPu+69q2n7RNx8+5yIqyqqqqruu6ypQUnM1ugDcgXXu9XC5NO4yTcdGJ8XuHaIYF70JwBMSRxcyWj2VnbvkmMS5Ulpf1y+tff72ez4XKi6IIuBcZPY19e3378c+PH2/XaRyHYRzS6c8Y51LmZV2fTqf6VNeVlIKLGfQgZ800Dl3bXC+Xpp/0ZOz86NlTxbvGnjBgO8qb90ZBuN0451JlRVmdXv76++9/vZxzIZXKPPBDeprGvmsub//8n3//87PR2mitzfL+ffRMqaKsT6fT+VTXVcU5ZyxIAJG1xrfRXK/XtrfGWhfd9TnVPmRcf09oLPiMN5WEUCoviqo+vfz1r7//fj1njHMuOAKAM9M49l3bXN5//vPvf/+8WuucdZbA+QYZE1xKledFFcGDsmAhewKAwJE10zQMvdeB3eicc4SY6iaaO/l8RGhBx262rbKsKMq6Pp3O55fXv15Paq6Dab31b6/Xy/vb28+fPxrnkzmjRCEyJqTKi6Kq6vpUV1VR5HlcoxA4sE5P4zD0fd91Xdf308o7/ZI/8BRHKMYCEGSWl1Vdnzz+c6qqMjKMnJ6moe+769U7P9d2WO5flvRCqrws67ouyyLPgvEEAADnXYi+6zx4MI6TjY//Oj3CALwhASGvgzFksijKU12fT+eX86kucxUud9bZaeyHrm2v17f3S9N0wzg3EBeWRIBc5UV1Op1OdVXmKgVDyFqth75tGu8CTdrChujRuFDKgA8Tl5ehBxO7vszn/QjOs6KqTqfTuT6dz6cyU755Z62xph+6rm2ay/Xt7XLthmnvwTkExmVenc7nl/O5KjK1XgWbcezb5nK5tl0/Luo/gWvgoaXwmgHzkm8PfeFsKIMVWMMmyDgXQkolZVbV9fl0PtXV6XwqMw7gFy9Wm6nr2uZ6vV6v7+/vTT/tXh8AEKFQeXV6eT2/nOsyk0nEGMhOY9c2l8v75dptfOD9a783QVssY5+V6MGNcZNb2I+0
jp4gMiEzleWZKurT6eV8OtVFWVe5RMfJWmOM1tPYtnH2Xy7tYI5ceEfIs7I+v76eX851kckUCyMzDe318n55vzTtMJnjRVB0hu6lKAGJ8djK0LIOQsQgJriqec+4UJkP959OLy/n07nM8iLPOFgiY4zR0zQOzfXy/v52uV7btu3Wvv9CTKiiPr/+dT7VdZGlMwCsmfr28v7+7mfQrVXgjFQdv8pbDJhvPfTuUy8YNhrAB/+981pW55fz6/lcl1IqJRhYE9CbYeyvl7e3t5+Xa9P1wzjZY1+CiayoTueXc10V+WoGeAjt+v72/n5p+0Hb3e1RF3xJAlKc8yBLGQPQs10Mgc9ekh7Frev6/PL68nKqc8E5R3LWGq21Hqd+6C6Xt58/f16uzTBNxjjg+yEAcpnlZX0+n6silysGWKOHvr1e3t8v16afzMHd9MFvnzNgRlPvZx+SB/5C9Pt0Op1eXl5ez3WdcUR0xlk96WmaxqH30vv+dmnaSRtHAMh2s80HEsuyqqsqUyLFgskvAa7X6/Xa9DdVwGIJ7qRfc4SQIWMh+aOq6vPpdD6f6rouFQAAWrDGTNMw9kPbtZfL5XptmqafjCNEnwmRaBIAYCGTpCiKLJOMLW8F9BhhhKbt+/FQAr5Cex3wiCOBjHGufPJLVZ9Op9Oprssi9xAuQyBn9Dj0Xde0zfVyvTY+/gNAMXo8+4A+jiakypRSSinJGYDz2waIaBr6ruvatm27vh8nY9zejHzgyNzDgNkf+/SmGI4KyU95CP6eTue6rooi5jwA+tVr3zbN9dpem0vT9eNkQmLj9mEMmRBKehKcMwyHbDlHzk5d2zZt27Zd1w3jpNdI2AqieWwKLKpvse7b+1OnZ/UM9JGPGP09n+q6LFL/HZzVY99cLu/Xpm2bth+XntN6/IicSyWVisNHIOecIx8EGjwG1nZdP4xTEg2AFT7x8LyIGSIEsyu4pRX84aUsXhQcgCj+p1NVFdH/BfBZHHrsm/e3t0vTDV03pPjFZkXBuFRKKSkFZ8xHP521zlljjNEzA/p+HCMW7P3yVS31Bz3hZQosNx7dnpbqTKYAegNY1TH8X5aJ/+6cc9booWsubz/f2mEcx3Ftvtf+tJBZlikpQtYPOWdt9COnvr1cr03Tdf04TsY6R9sw9s0zTj5mwD06Lzn/xPMpRuOYkFkeBKA+neqqKBLjbYwxRo9D21zefr41k9Faz9Zr41ExIaTKiiLPMsmZTwtx/u3raZqGsWsvl6bxE0BbHxDEVf++Mvy1ErxRC37DAljWC8i5T/+o6/p0qquyLHIejTf5+N009N31cnl7a7WzdtkFhZiOXwRnoihypTwLCZy11phpmoZh6Nv2cr02be8hdJ8zs5WAx4e/8QM+ujtFXuaIPJcyy4sA31ZlmWcqAtg06WkYhqHvu7Ztrtdra5NshrC0ZkTkbamQSuZFVZVFrqTg6DEUZ63W09D3Xde012vr378xfkl+u4dfZcCj5NMfyrKqqqosiyLPlAx9cWb04G3TBOPV7/1rZACIyIWPg+RFdarLIg9pckTknDF6GvqubZr22rRdP0xaezv6SzjQQh9nih7bw/gSkYtZAKqqyLOQ9efIkdHDMHRdc72+X69N2w0jrRsiRETGOBN+C5FUKi/K0/lUFV4LeCPirAfSm2t7bdp+mHTEAT7y2B6YDcew+ObDPG4MmpAA/B4WlRVVdapPdV0WmRKCAYBz1lo9RtF/+/l+bfvBbB9CBIz7DSRKSqmUyvKyPr+eylyFOAB5R2ro2uvl2rRN26Ue0Hy0043u3xccixkim/u3lj/oP695ggRgcAKCAswzv3ghZ4zRU99er9fr5dJc3t+v3TAdPNsBk0VRFmWulJJKZnlR1OdzVSjBGKDzs8BMQ3e9vF+aruv6NIhyS2enia2fc0DsvL+4+I9TIK0ZnWZfYMiBrevTqa6rIpcR/ddmmvouLP7a9nppunEnAAAAKLKqPtcnzwGpsiwvy7oq
Mskx+DhkzTR0zeXt0vRjP0ypH3E8uvQ4sgcKKm683QQCSfmYAKeITAilAgpel0UA8Hz2wtA17z9//Pz59t4NXdsO02FHUOTV618vL3WRKxV20+V5kSvBfJoSkXOeAe/vzThNozZuL/YbSfCoBd1ZkGarA5YA4fxbNDlz0XrPoSABRVnXdV2WXgB88Hrs++by/uOff378fO+ncRj2GgAAALjK65d//f3XucwzJYQQUkiplFScIfmVQAgFXy/vjTbaGOMI2HodiJC6BLjEOR+bAmsWwMoKLFNgsb8BCfFWsCzzTDCG4Jwx0zh0XXN5+/nPv//5cRmsNlofIaAgVF6dXv/+3/96qfNMcS64YGFPLQMiR84aO+MgjXHej0qLqs39TYYa8wUfUoJrDuzKnO0+LBzI8rIsyyKXDBHBZy/0XXO9XN7ffv74+T6RcwdLdwBgWV5Wp5fXf/2v1zrPpGA85MP7vQDeD558OLFrmuBIIYOPU4HmE+l+wQzeRwENyLKiyJSUCH79pqeha5sYAGtMvDjJQCFHgEKVdX06n19eXl7qPJOcJXkgQOSsMXr0kcC+6/vgSDFKZuZBpx71j36FAT4ZSGVKKSnC4sbqse+ay+X9/f1ybdrej58xznkMrTEAAkSpyvNfry8vZx8Ek+uehzyAoe/7vh+GcYpqlOKRlJvTdKLW2iZVfUY3HKE4BT4UIx/NlVKpOZGT7DS0l/ef7+8/f75dmt7HP5kQUvJYJoAzBGRMZsXpX3///XquyjxXatM2OaOnoWvbpmm7fpx0VKNEHqGmYJKT8kFBOfvJ/6UdI/P4048fTjbmtwPP8K0zY3d9+/nP2/vbz7dLG8RWKKUyv9HN759HzoXMiur1r79fT9U6BjozIEylpu2GUSeBJM8Bn3I148pzFbkEabsrNrZ78kOYqF/ICREhAHJ6bK9v//z759v75f3aT2H8eZ7lSoSSCVIILqRUWVHU59fXc6nkbt8XOWvGoWuul8ul6YZpNX7CqOiInJsd0yj8EWqjr1qBQzpeFiEi44yLed55Cfjn3z98BBwAAHiWF2URLL3fQC2VUlmWq7Kq6zqXLElri/E5q6e+bS7v79emG3QaSKMQjQcARwDO22c/vyCpoHXfNHhcCc5IBkUOJPvWnJmG5vL248fl2vWjX7aqvKiqsswyqYSUQmZ+B3WeKZllRZ5LwcABEjkkQObREGf12LfX6/vl0nSD3kZBkDEEJHQRX8fIAFoQq7uUwCcMuOFu0yxgflE7u81kffzqcmkHHwBnWVFWdV2VReYdfplneZ4XWaGUzyhlDIgIHDkAYBw4ApC1xpuTy7XtR51EEgkhpmOTA3CzdfXl1RzBMvt/aQrcuHlGsqLyXR2FG63XMPr8Dc5VVvqEnyLLMqWkUnle5EWZ55lEZIwjEToIdXHQ+Y0vPqDStc312rbDZFy6ICGIAp88OvSDlrrE3+UILTmbMzacqrC4jQc5AONSqLysPF6aZz7mk+dFUZR5rgQAIJD14R/PAEZEhOSc9UhQ23bDEJMhN/3Am27PA8XrHk+TW/ZqLxxI/hr8Y4uGkAvpU55Op7os8ixTUoVN9GWW7xpMB+B8Tm3nUTDr9pncX6g+ekQHDKD159XhH2GhTPFNwzqfHZHLrKjPI2SjI8allCoPOiDPVaakkipTeZ4ViefjI2CAiORdJY8D6HHsu64bRm0sQdhDiuHRDplDWnZMRXlc7yO7hwEr1AsA1vGh1WzCmAzndYDfsrX2HJjMynM3YT5MDpgQKuZNl0WWKR/487HP5R7ngjFnfq9BCIqYaRz6rvMwYNAMsTPk0AGi31QcviVyiEDuoeEfSMDOfVj9hhDTZELgjmhV1ZTJoh405udJEzAhpFjC3SpUy5FCyuSxPtnVFxkl9EExnxM+DUPvx0+ELHm5ROTtHzkXSmkhgfNm8EEOxJ7E1x4KIaWRljkDNymBRqHa27qWFQCXxcmx7NRr47P+hZAqy7IsD4VyfNWo
ZP+PM9YG/575WmIw64Bx6PvBGGMBWcCH5vEDgo8ezW5vYADd5wHODNhNgZV6iR5VYNJc0A4gbN5zzrkk3MlV4Zgs+8k4AF8qTAgllVRCCO73vC55L2Stsz7MxwgZckCcw4JGT9MwjKO1znnF4Ii8xiTHHCEgJUUHo2t2kOX3MQO2X3jYZ+V5pkowljH1friPXRot4xVMFijz2pdCYIxxv1jgQnBfMm3eNkiO4t5PAmTgq8KEeoE+t2gax3GcHDlCjiwKuWd+8EdnDhAsuX6/pgNuuf2eAzFNyvci9HOaFgYIZKrQ1vrNQ74CHg9FEhdMxFprnXXkJxAhCiQAFippgbPWaN/ypImAWJx9iIgEhLTLCEsG/qsMuJc8ZDsNfZexLKyHUTA+GweciwSuoGqw1vg9guQIgBAZs0iLW+dsiKtqrbUBD26ms+/wNX/RK/il2CBZPfZtUwowEpl/w3Nm026F68k5MlYbbayvEueX1MsGHCKI715rE1KhHjRtj9DnDFg5XDH/0msCZ6a+uRTC6SIXUojUpKzJkbPOh8fDxldjHQFDzrkQgliM95NzzkzDGLcShzggkcNgA/cYzYcY6Wf0AQMSfHz2tjDCDd4hNGN3KThNQ13meY58P3J/nzU+ZdRoY4zVYc8vMC6kkkoRcE7emjtnzTQM/TAM6VZCIoewbJNOhxtg8a2/cvdi6PCyWEwg6tX1KR4elkREMkOjuNN9dz7VFm8x04St3tM4jXoK276NI19fJy8IQPjBETmrfTHJfhyTfHBygNEPWpdsxDja9Xd300FoLDYxG7yk+egXICIwcHpoGOih7brRgMgOmUmT1tM0jdM4DP0wjJOe9GSMBeIyK6rSECKzzttvHw3ugwtsAZgFACCyc4VynP+JXd0uk5aw1pfwgPn2eQpQDI354Qfw1UsAIzP0bTcY4llh+a418KMfh2EY+q7t+mGYJq21dYAiK6rJABdcWucBrJAP0HU+oZAWr8vzIXZwKR85v+37xrtnwPEUSH9up4Dfp4jIkMxIeuy7djAksrIyW3Abgjvj8f22bXySj9bGOWQqryZCoTJjnV9Z0pxZ2vl0YIxw++Jv7993+s1KbO9iwK0/rMJj87IgWrgAv9jJjEPf95p4Xlb1kK8rXRijJz0O49APXd+1bXNpmq4fJ62tI8azUpNQhTZxVUt+DTAOfd8PozYWYj2meUNoGl+Ky5OFLV+gx/yAyFQvA+DIIo7jqIlnVV23hapmeHh2E0efKdV1bXO9XD0DjCXiIjcoq8nYxAci54yexmEYB59THhkQDCAevd/vMoOfk49WGks8qy6nqpSgFQCgT5FcM6Dv2uZ6aZqun/zQhADuq6YgYwkHrNbTNI6T1vYXSwPcRV9hQMwWJCBHBI5Y310vVS5prCSE1Zyz1lqtJz2Ow9gPngNN0w7D5FUgoHUEyLmQPjnULzOc83zT2hjnVh4gYaKbn8UacSsI+GHafLBZ4L15q8e+vWScxmshvYfkyFkXa52M4zgMQz90rVeBxrlQaIELKbMsy5TgnPkaOBRu09ocSADOsfF1SsTXAcKIB6w9KVp9cWzeCWJVa3Jm7C6cdFNlAmekwnPA6MnzYBw9vuMr5SATUsosy/I8L/JM+cw4DEthY4zx9XPW+eRxE/F2496maw8xIPm8KaG0WnptGp2XYwiASHbqBOmuzDMfBI/5LdYav7DVgQ2TtoTIGSBXc5WEosyUDHOA4iI7MmD12Ij9x7LZh2JKj5nB5erkvs+EaobgABgwhmQGNP0lU4IhAjIKGS7OWmeNtca/VK2NdYAcAJlQS5GIUkoueIh3hvx4Y9OqKmH0UQccd+1LRw2sSmndGu7R+AEAyEf6GToNdhAx0R8ZYSz35fw60KNn1pKjAP761JJQYFj4+vp+ORhy5I21W3xzhmO8lN3d3zsYsPEjP2vCb89kQP48ACDrJpwJmIfrnI9eUyib5zdHQwCJpFKZ8iqgyIMdxCgA2njAIOXA
0vzzTADcuV9gTwQA4NDDuQDOETlLRKGMFCD4aqBzpDosIRky8OnhcbuxUkqpGWkkinCY58D3ASGRfr1+AHmI3K72sd2oL4+Mh9yOsOeecc5FUi0vCMA0TZMx5kZ62VPpo7K6H1KEqhwjAuessbu/7u8hCBVCl7giY8iSC5y1evKkjyTg6ed2JWV1P2x24y0FJUjkGDkAZw/rOR3RDC75WrG0iq1ScAPHcS6ut743nCLwRIxQ3Cy1kq4/51XgxlMgAnIY697eST6MF+NKq/ucd5/HcRjDkml1H85yB08TA7G8/XvPaZ07hAQEbo5H3UcIsSJmLAeZ3OucNWaaxmEcJr9kWD8x/Fied5wW+Mi0Fneex7JfMMR1umfAXQ+bV+/kwM0VMZebiaw1Wo8+JKQ3uyNDF2YnlHa+y37l8Hm/EiT75sWYfkyuIoyTcuex3hBRFuvBIiwVQRNV722A9uPXa0cw8U0/1oMzWnb7koTuswLHOmBGCWHZC+3Dm1HON60kboxDa60voZi+ZbLWeAh18rUS149b3v1HLLhdFvCA7vUDbqyZ56UHxVpk84qNiFaFNuJ8JS84wGNo2NGy7nLRDnjccDPKRQKe5ws+wRECgCU8gcdqaXtHmAM7MwBEHknSxt5vWn+FnlVSM9FIs49wdJm3mwAUGLA2oAgBETGhWuzODwr/fhkD3dFdDFhVkPwAh0hrER6Y6rnstdcac9Hg5TJ/lpyzdi6kejwHPpwBuHkVH9NxXGDd3lqhfmAv0sNf94feBh3mQm6r4dtyscFMkgdS3HGuC82Nzc2udeKDwiHm7t68ZN3iJx5DrOsMRzWtaMa4CZBCjsCqiATzaaLW+sypjx50u7MzaHKPa/fB5un01a8Ny/HKwecxJLZ6h6MhAVB0EhzfqcGQI0bObeHA5Bnzv77JWe7SmQRwh2/nKU6B274FrtnwwR77BayMumqnBJY0a0K3NwM4Hzd6MySQVP0KTW77G4sd3VdUc946e69h/ZCxtP77rRlMAB4w2uZ1zqUzduZxGefOMGx0wIP0uRJ8kD5sj9Ify9CXg/L8jtkFSprpeXZvQ7//2N1lxt9ymR5N9/0l+h0MOJBeREDujwrZ6pR0+AuDPlqs/RKFKfClg9qOBHO/Il0uW7IagCEy4bMnU10V86DmmPedDshRH+7k15Ii85GHd/RH3H3w/jBtr1lOf5nNKDLGWCgZhGkQbjkwBDyf0pEcd2T37h6rqpfUEbpldxZbk16Dm5/h87aTIZiDAL4iAgDE3XbKHyC5YkCUAIZIQUssi+DD4UP0u+Zv8DEJuH0drd/tEaUCHm/ZT4KQKLq4Zcj8AYSZ9GdtzbdEhAyRMUc+wkDRe9hmR93q9FLv4R4mJAUVj9vbzfPUHYT1PA2Pnk8miXEsjLv6w/fIuBAqy5QvHE8rCSAI4aZoJxIYZHeo8SGtsv0/o4MaIrvmAA6mACxfb39dTwHEqNUwbvFjjAkfEvInza50AMC8JwyCBKz78gklrupdOuARxXpEeMdECRY/cXeCClAiBoVDAzFSMHPsiKlf7usRsc8veT75I6ilT41J9w3vPYBv8wAj/QkGoC+9IGVMDpr/8hs9wEh3eoKLcX6c5gVysuz1oWEl5f7IvBQkDv/SYb3r55C4Z2i0+XnrKlrDAJvjOpJwDuNcCjXPgbUSWDa/I1HUvA/Z9kfoEyuwZU/y+3qyHurdZAuiH0T4M/MnycfimckRMvFn2AxHFD8HvjyfAx+lyNyUucQypi5YhDrWl87VOtOAHg8HawohVhIwN5KMeWbAsR1+BAE9oGfpgKV3tPo2ui4ECUNDYoSM5yfj6hZarowMcMk0iG3c1/FP6S4GHOqAjb06ugRmCQifZw/RJ8nwcJLwwWBm6V8k4Kj5X6dfxgOWnt1YTtHiwM8S4JeDLGynW+X8J45VEHvaCdYz6bcgQvd3HtcQwE4Cnk/fwYCb+OVMlFJ655wIt7Twza7RExiwllBcq4XjWRFi
HyETcLkmDH/ZIBkM7qc8+DqTnqADNnZoPo8lmL9932gJAa+TQMIyMFkC3SMB/gV81Sr8GgNovyLGWNMpOkAHLCDrQka8WefEI4tQwPKAGynRR4tv3Px+D33EgDuDS9urGDCIfs+xBPhqudMUUuGWBAm/2Xq3Cj7qx9Y7/eokOECEHhOmdWxq2TEN5NDRIYzlt0dO4+TzYFZnzXmzmHDg47ewEsEIOD3EjCcmSISnz4qcUThbficGIRNozoZceI8zC7YNr2k9SyDVAQ8qg2dZgZlw3udOt5LafX0IERiQ5oN6+V87R0e0Xj+vv36iDvgqRUDrthvgQjLYFA4OXu5MPMTvtf+Rfn9sECDujPIMCEVUPGFCAN/nAC/04HmDkT7uGC05AgctYCgbz4MZsC4tlJUwIGg42twPH+u5ByXnPgbsJuTtRThROPPuIMYdk06InDVMj+O0ywfFhQUxAzE5jTD8M+v+zQIZw/U3unZI3yABM4BDWxAnFr0kcmg1Mn9YQMqBEEZKDIGvapReEPfT7nysLzmDH1WRweTHPgQSYKGNgQMkABdTRbdNznFSchYAox1cp8NgWC2HMtZufsZcyivK+cHkgDlR6c7o0EcSsLv9SAL2LAjk5rT2BDoJ4/cHKEGwg8bauRtxHw1nLLziNPNqXidtT4f0bccp8HxHCA8kAFY9SCiMOdY42mqBWPbSOZqmaAgWLcjCPiLhrK8ogAusNC8zPlGCjxiQp+uAZV/HfjE/z2sgtBgYME2T1vPZRB4vVVIqiruLElgtWSo/y034Bj9gxr+O1vGh8r3fexrKa4yjimfz+OOmVJZrB8w5cs4xl4r67CI+y0P4I44QgJ8k4QiZvheCCb8E4kJlWVH2Fhg3ztpv94W+jQH39NsfI9w2EhkozjkC4zIrqrrXyMUwGb3fi7IBzPbP/KbF0HZFe3jRbg/+5q+40abOF+HNOSA4qZABk1lZnXvtuBBiGIHsJmtsYQAtxjg+zSu/7/AEjxL0j13dm7uaZ28iMWv+HNlCIiFQzjgAl3l5GicLUkrBwFkESLcezFVUfCnqJB0mZcYj0+ahKfBRu4k3Hz7f4ETKTKeH7ppxXy1bSAAUWTVqQ0xJwcAZzTZ70jBhJa2ywSlKwWPbKe5gAM2yvTeDd65Jtl5ElACrxzaTDAgZ58oCgMxKbS0w5cc/8TR0jMscCD5POv4lNWjnJn5AT5OA0Ku5xj3tkvUPjgV3euiUYEBcSFVYx4ArZ5xDJhk4M43jyBlbZeelShA36VCzx/0ARnyvDvANb7+4QUePT+MF8x47Z6ah5QxQqryotGUAkqxzgAKcnoZ+HI1xRG5TPm2WgOOOPBJIeqoEPErk7DQIzpGrvOrHyUgAFKqwDtDqcej7cbTOLyu+CyD6Y44QAAA4a6ZRSK6KrhuGcZKIiFxm1jlfUGvSBIgW7XYT5tPojzKAyBozyUGEs1R6nz7OpTJ2God+GI1jjHONYMMNYX21CSlumn2kD3+YAc4ZM8lpHIe+63LJhBDcAZOZ09M4akMohOADUKi/mmIMaV518DD8OuOh2fJUBkRdeaSEoxJcYUSeA2aaxqHv2kyiVFIgMS5tXk7aWORSCM7IriVgri5Jc8PBy7yJRt6ib5aAnfnY/DkcrOoZoDgooxRHYlyZwlgHXCgpEIzWvjVywACWdfJSX2zeKvVgLtXTGDD74ssX+30fu3t8nNyfT9dKBpmxTnIETqpwDlBIJRi5aeSYNLrsKgqSj7ODMCck3UvfIQEfmKyt50IUz9Tqu4wjFdY5JwUik44AmRCSkZn6joXbA+K+QtswQskQHcEHtMBv0wHJNUnvyNcO09PYK8mRwpZ5jkwAIHIhBNixbxUP1VVxlgAI7AwKYIGR4eDIitv0h63AXDhoHKTkQJa8WuPIBDBknDE7dW0m55pnCwMSqDllwGOY6B92hMLx2lpMg+SCATgAZJwx5MgRERnQ1Be5kpw/OLB76c8ywGsAxibBBePg
owWMM4bIGDLOOMLUVf4w5rifaNsILpVXH4ZDfi8DdmtoIgcOrWFsZMzXUiSaUwQ4khAcpq4qizyT2jmv99cL7Xl1TKFJovT3T+nZDJgPJt394dAcErmYUeEtIoE/uwuRM45Scphaf0qVNugIw2kz4XZcAiIxIJs6iPfw4PkSsM2Z+fhacuBPEnTOGWOML7PKOeOEgqPlMDZVWeSZmsACQdh/trTv9xZRXAOHPTd/yAxGiokxO9do84FgqUtqnTXGWAImhOAMuSTkyJktyiLPskxJN0v3sv9u+UEADhJgLAGyPqRv1AG0+nhDFMihdwiNNVobS4wLIThjXCpCf1aj8odzm7CnjiIgCtEBgMQy0lpCPqXf5gnS4awkACIE5yy3RltLyIUUgnPkUirBAfyJjkIIzi1zCBB23yWhYt/Ssqtgoxc/pj9hBTbfEIIjZ6wxzhJyKYUSHBkXgmWMnCPyx9QzP/6bw1r5Rn/MCjxMQXcjOu6cQy6VEpIhAWNIllOoqvZwxOde+uMMAPDaC8g5AsalEIyD09roqc85jJf3ph0mn1HqJWZfYMvLfHSF8RF3+HsYsFsbf0YOAIgAGRecA+mx67rmWmQcpvbnv9+u/WjmkktEi7GPGp8WhxBn83MXE77HDM7/7L+GFa6d/NURoWYjZ2DN0BVFVZZZxkF3l58/3q79aEw4aQzSpXDaKLvR8of0PVZg983mt3jBylwQOavRH97TZlleZJlkYMb2cnm7duNcbX+VfRZK02GUAKS5zOF9uTLPY8CRo/vpxWsOOLQaweqx8/WGpUBw09C1TduN2sRkqogGptDHnGC93rv2YCWpZ9FnUNjN+5z10SIhBPfuADjjD52LNWZpjanQfAbMnFwaveH7/OHvtAIfduAwzEaWnDXMH9LtN9WS9aW2k7JzyytexNxnFwIAoQM4iEzeov8IMxiJCMnZmCkbth2QP5gvzTpNPni9HzEx9I4lPQCK/ScxgD4Rmpv3BQdgxgb+n8EEn0wpMn43/ZEKEs+mGRz39GlR05T+Z0jAbAvhYS/0f4QEBMLk37vp/wK1FNcBypIAsQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=256x256>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "!pip install -q pillow\n",
    "\n",
    "\n",
    "import base64\n",
    "from PIL import Image\n",
    "from IPython import display\n",
    "import io\n",
    "\n",
    "\n",
    "# raw_data is a 28x28 grayscale MNIST image (JPEG-encoded) of the digit 9\n",
    "raw_data = base64.b64decode(b\"/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+rVhpmoarP5GnWNzeTYz5dvE0jfkoJovNMv8ATmK3tjc2zByhE8TIQw6jkdR6VVq9oumPrWuWGlxyLG95cRwK7dFLMFyfzr3aXwp4ltAfB3gWwudI01JNuoa7eZhku5AMHafvFOw2Dn6ZJ4z4yeLk1HUbXwrZSSy2Oh5heeaQu88wG1mLHk4wR9c+1eXUqsVYMpIIOQR2r1D4QazqOs/FnSG1fVLi9ZI5vL+2TNKc+U2ApYnB7/hXml5LLNfXEsxLSvIzOSMEsTk1DRVnT7+60vULe/spmhureQSRSL1Vh0NWNd1mXX9ZuNUuLe2gmuCGkS2QohbABbBJwTjJ9yelZ1f/2Q==\")\n",
    "\n",
    "im = Image.open(io.BytesIO(raw_data)).resize((256, 256))\n",
    "\n",
    "display.display(im)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
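    "The inference service reads the input image from the HTTP request body. The SDK's `raw_predict` method accepts request data of type bytes and sends it to the inference service as the body of a POST request."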
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
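    "import numpy as np\n",
    "from pai.predictor import RawResponse\n",
    "\n",
    "# Send the raw JPEG bytes as the body of a POST request\n",
    "resp: RawResponse = predictor.raw_predict(data=raw_data)\n",
    "scores = resp.json()\n",
    "print(scores)\n",
    "\n",
    "# The index of the highest score is the predicted digit\n",
    "print(np.argmax(scores))"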
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After testing is complete, delete the service so that it stops incurring PAI-EAS charges."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.delete_service()"
   ]
  }
 ],
 "metadata": {
  "execution": {
   "timeout": 1800
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
